import copy
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict, Counter
from sklearn.feature_extraction import DictVectorizer
from concurrent.futures import ThreadPoolExecutor, as_completed

from dsl_design.src.cluster import DPMM
from utils.distribution import N_Gaussian_Distribution
from utils.util import read_json, write_json

class Production:
    def __init__(self, recognized_data_path, operation_dsl_path="", same_components=None):
        """Load the input JSON files and set up empty production-DSL state.

        Args:
            recognized_data_path: Path handed to ``read_json``; the loaded
                content is kept as ``self.data_recognized``.
            operation_dsl_path: Optional path of an operation-DSL JSON file.
                When empty, ``self.operation_dsl`` is an empty dict.
            same_components: Optional mapping ``unified name -> list of
                original names`` used to build ``self.component_mapping``.
                ``None`` (the default) means no aliases.
        """
        self.data_recognized = read_json(recognized_data_path)
        self.operation_dsl = {} if not operation_dsl_path else read_json(operation_dsl_path)
        # Create a fresh dict per instance: a literal ``{}`` default would be
        # one shared object for every Production built without this argument.
        self.same_components = {} if same_components is None else same_components
        self.component_mapping = self.__unify()
        self.production_dsl = {}
        # Template that ``extract`` deep-copies for each newly seen component.
        self.pattern_example = {
            "Pred": [],
            "FlowUnit": {
                "Component": "",
                "ComponentType": "",
                "Vol": [],
                "Container": [],
                "Cond": {}
            },
            "Succ": []
        }

# {
#     Pred: <Operation.UniqueName>,
#     FlowUnit: {
#         Component: "E.coli",
#         ComponentType: BiologicalMaterial,
#         RefName: <STR>,
#         UnitArgType: PROD,
#         Vol: (range),
#         Container: Flask | Tube,
#         Cond: {
#             Temperature: (range)
#         }
#     },
#     Succ: <Operation.UniqueName>
# }
    def extract(self, production_dsl_store_path):
        """Aggregate recognized flow units into one production pattern per component.

        For every sentence of every protocol in ``self.data_recognized`` the
        input and output flow units are collected. Units without a ``Name`` or
        whose ``Superclass`` is missing or ``"NONE"`` are ignored. The
        remaining units are grouped by component name (after applying
        ``self.component_mapping``) and their operation, volume, container,
        condition values and following operation are accumulated. The
        accumulated lists are then condensed by ``__abstraction``, written
        with ``write_json`` and three summary numbers are printed: the number
        of flow units seen, the number of resulting patterns, and their ratio
        as a percentage.

        Args:
            production_dsl_store_path: Destination passed to ``write_json``.
        """
        num_flowunits = 0

        for protocol in tqdm(self.data_recognized):
            for idx, sentence in enumerate(protocol):
                try:
                    recognized = sentence["recognized"]
                    candidates = recognized.get("input_flow_units", []) + recognized.get("output_flow_units", [])
                except (KeyError, TypeError, AttributeError):
                    # No usable "recognized" payload for this sentence.
                    continue
                meta_flowunits = [metadata for metadata in candidates if metadata]
                num_flowunits += len(meta_flowunits)

                for metadata in meta_flowunits:
                    # A missing Superclass is treated like the "NONE" placeholder.
                    if not metadata.get("Name") or metadata.get("Superclass", "NONE") == "NONE":
                        continue
                    name = self.component_mapping.get(metadata["Name"], metadata["Name"])
                    pattern = self.production_dsl.setdefault(name, copy.deepcopy(self.pattern_example))
                    flow_unit = pattern["FlowUnit"]

                    pattern["Pred"].append(sentence["operation"])
                    flow_unit["Component"] = name
                    flow_unit["ComponentType"] = metadata["Superclass"]

                    if metadata.get("Volume"):
                        flow_unit["Vol"].append(metadata["Volume"])
                    if metadata.get("Container"):
                        flow_unit["Container"].append(metadata["Container"])

                    condition = metadata.get("Condition")
                    if isinstance(condition, dict):
                        for argkey, argvalue in condition.items():
                            if argvalue:
                                flow_unit["Cond"].setdefault(argkey, []).append(argvalue)

                    # Record the following operation for every accepted flow
                    # unit, mirroring how "Pred" is recorded, not only for
                    # units that happen to carry a Condition dict.
                    if idx < len(protocol) - 1:
                        pattern["Succ"].append(protocol[idx + 1]["operation"])

        self.__abstraction()
        num_abstraction = len(self.production_dsl)
        write_json(production_dsl_store_path, self.production_dsl)
        print(num_flowunits)
        print(num_abstraction)
        # Avoid ZeroDivisionError when no flow unit was found at all.
        ratio = num_abstraction / num_flowunits * 100 if num_flowunits else 0.0
        print(format(ratio, ".2f"))
    
    def __abstraction(self):
        """Condense the accumulated lists of every pattern in ``self.production_dsl``.

        ``Pred`` and ``Succ`` are reduced to their most frequent entry, while
        ``Vol``, ``Container`` and each ``Cond`` value list are replaced by
        their distinct values ordered from most to least frequent. Patterns
        whose values cannot be counted (for example unhashable entries) are
        removed from ``self.production_dsl``.
        """
        to_delete = []
        for name, dsl in self.production_dsl.items():
            try:
                dsl["Pred"] = self.__most_frequent(dsl["Pred"])
                dsl["Succ"] = self.__most_frequent(dsl["Succ"])
                flow_unit = dsl["FlowUnit"]
                flow_unit["Vol"] = self.__sorted_unique(flow_unit["Vol"])
                flow_unit["Container"] = self.__sorted_unique(flow_unit["Container"])
                # Only existing keys are reassigned, so the dict size stays
                # constant while it is being iterated.
                for argkey, argvalue_list in flow_unit["Cond"].items():
                    flow_unit["Cond"][argkey] = self.__sorted_unique(argvalue_list)
            except (KeyError, TypeError):
                # Malformed pattern: drop it, but let KeyboardInterrupt,
                # SystemExit and unrelated errors propagate.
                to_delete.append(name)
        # Delete after the iteration to avoid resizing the dict while looping.
        for name in to_delete:
            del self.production_dsl[name]
    
    def __most_frequent(self, lst):
        """Return the most common element of ``lst``, ignoring "NONE" entries.

        An empty string is returned when nothing remains after the "NONE"
        placeholders are discarded.
        """
        tally = Counter(ele for ele in lst if ele != "NONE")
        if not tally:
            return ""
        (winner, _occurrences), = tally.most_common(1)
        return winner

    def __sorted_unique(self, lst):
        """Return the distinct elements of ``lst``, most frequent first.

        Elements with equal frequency keep their first-seen order.
        """
        ranked = Counter(lst).most_common()
        return [value for value, _occurrences in ranked]

    def __unify(self):
        """Invert ``self.same_components`` into an ``original -> unified`` lookup.

        When an original name is listed under several unified names, the one
        iterated last wins.
        """
        alias_to_unified = {}
        for unified, original_list in self.same_components.items():
            for original in original_list:
                alias_to_unified[original] = unified
        return alias_to_unified
    

    def operation_simple_clustering(self, domain):
        """Cluster operations 20 times on a simple type encoding and save the mean curve.

        Each opcode in ``self.operation_dsl`` is encoded from the pattern that
        has the most examples: one slot marks the first ``Precond.SlotArg``
        component type and one marks the first ``Postcond.EmitArg`` component
        type. ``DPMM.cluster`` is run 20 times on these vectors in a thread
        pool; failed runs are reported and skipped. The log-likelihood curves
        of the successful runs are averaged by ``__cope_curve`` and written to
        ``dsl_design/data/cluster_curve/{domain}_product_curve.csv``.

        Args:
            domain: Name used to build the output CSV file name.

        Raises:
            ValueError: If no opcode has any pattern, if a component type is
                not one of the known types, or if every clustering run failed.
        """
        component_types = ["Gas", "Liquid", "Solid", "Semi-Solid", "Mixture",
                        "Chemical Compound", "Biological Material",
                        "Reagent", "Physical Object", "File/Data"]
        num_types = len(component_types)

        # The encoding does not depend on the iteration, so build it once
        # instead of once per clustering run.
        base_vectors = []
        for patterns in self.operation_dsl.values():
            if not patterns:
                # Nothing to encode for an opcode without patterns.
                continue
            # max() keeps the first pattern on ties, and still yields a
            # pattern when every candidate has zero examples.
            selected_pattern = max(patterns, key=lambda pattern: len(pattern["examples"]))
            body = selected_pattern["pattern"]
            vector = np.zeros(num_types * 2)
            slot_args = body.get("Precond", {}).get("SlotArg", [])
            if slot_args:
                vector[component_types.index(slot_args[0])] = 1
            emit_args = body.get("Postcond", {}).get("EmitArg", [])
            if emit_args:
                # Output types occupy the second half of the vector.
                vector[component_types.index(emit_args[0]) + num_types] = 1
            base_vectors.append(vector)

        if not base_vectors:
            raise ValueError("operation_dsl contains no patterns to cluster")

        def cluster_iteration(i):
            # Hand every run its own copy so runs never share array buffers.
            vector_list = [vector.copy() for vector in base_vectors]

            # Run the DPMM clustering
            result = DPMM.cluster(data=vector_list, Distribution=N_Gaussian_Distribution,
                                feature_dim=len(vector_list[0]), iter_times=1000,
                                alpha=0.1, regular=0.1)

            # Log the result
            print(f"Iteration {i} clusters: {result['K']}")
            return pd.DataFrame({str(i): [float(num) for num in result["log_likelihood_list"].split()]})

        # Parallelize the clustering iterations
        curves = []
        with ThreadPoolExecutor(max_workers=4) as executor:  # Adjust max_workers based on the number of threads you want
            futures = {executor.submit(cluster_iteration, i): i for i in range(20)}
            for future in as_completed(futures):
                try:
                    curves.append(future.result())
                except Exception as exc:
                    print(f"Iteration {futures[future]} generated an exception: {exc}")

        if not curves:
            raise ValueError("every clustering iteration failed; no curve to save")

        # Combine and save all results
        combined_curve = pd.concat(curves, axis=1)
        combined_curve = self.__cope_curve(combined_curve)
        combined_curve.to_csv(f"dsl_design/data/cluster_curve/{domain}_product_curve.csv", index=False)
    
    def __cope_curve(self, data):
        """Average the per-run curves row by row.

        A ``log_likelihood`` column holding the row-wise mean is written into
        ``data`` itself, and a new two-column frame (``iteration`` taken from
        the row index, ``log_likelihood``) is returned.
        """
        row_means = data.mean(axis=1)
        # The column is stored on the caller's frame as well.
        data['log_likelihood'] = row_means

        columns = {
            'iteration': data.index,  # row index doubles as the iteration number
            'log_likelihood': data['log_likelihood'],
        }
        return pd.DataFrame(columns)

    def operation_clustering(self, domain):
        """Cluster operations on their merged feature encoding and report the result.

        Features are extracted from ``self.operation_dsl``, encoded with
        ``__encode_operations_merge`` (``__encode_operations_select`` is the
        alternative encoder), vectorized with ``DictVectorizer`` and clustered
        with ``DPMM.cluster``. The log-likelihood curve is written to
        ``dsl_design/data/cluster_curve/{domain}_product_curve.csv`` and, for
        each cluster, the counts of the ground-truth superclasses read from
        ``dsl_design/data/demo/opcode_to_superclass.json`` are printed.

        Args:
            domain: Name used to build the output CSV file name.
        """
        operation_features = self.__operation_feature_extraction()
        feature_space = self.__create_feature_space(operation_features)
        operation_vectors = self.__encode_operations_merge(operation_features, feature_space)

        vectorizer = DictVectorizer(sparse=False)
        X = vectorizer.fit_transform(operation_vectors)

        result = DPMM.cluster(data=X, Distribution=N_Gaussian_Distribution, feature_dim=len(feature_space), iter_times=1000, alpha=0.1, regular=0.1)
        print(result["K"])

        log_likelihoods = [float(num) for num in result["log_likelihood_list"].split()]
        # Number the rows from the values actually returned, so the two
        # columns can never disagree in length.
        curve = pd.DataFrame({
            "iteration": list(range(len(log_likelihoods))),
            "log_likelihood": log_likelihoods,
        })
        curve.to_csv(f"dsl_design/data/cluster_curve/{domain}_product_curve.csv", index=False)

        clustered_operations = {}
        opcode_to_superclass = read_json("dsl_design/data/demo/opcode_to_superclass.json")
        # Labels are matched to opcodes by position in operation_features.
        for i, name in enumerate(operation_features):
            # The lookup table is keyed by the opcode with only its first letter upper-cased.
            ground_truth = opcode_to_superclass[name[0].upper() + name[1:].lower()]
            clustered_operations.setdefault(int(result["label"][i]), []).append(ground_truth)

        for cluster, ops in clustered_operations.items():
            print(f"Cluster {cluster}: {Counter(ops)}")

    def __operation_feature_extraction(self):
        """Build, per opcode, one feature dict for each of its patterns.

        Every feature dict holds the pattern's ``Precond.SlotArg`` list under
        "Precond", its ``Postcond.EmitArg`` list under "Postcond", and the
        ``DeviceType`` of each ``Execution`` entry under "Device". Opcodes
        without patterns are left out of the result.
        """
        extracted = {}
        for opcode, patterns in self.operation_dsl.items():
            if patterns:
                extracted[opcode] = [
                    {
                        "Precond": entry["pattern"]["Precond"].get("SlotArg", []),
                        "Postcond": entry["pattern"]["Postcond"].get("EmitArg", []),
                        "Device": [step["DeviceType"] for step in entry["pattern"]["Execution"]],
                    }
                    for entry in patterns
                ]
        return extracted
    
    def __create_feature_space(self, data):
        """Return the sorted list of all occurrence-indexed feature names in ``data``.

        For each pattern and each of its "Precond", "Postcond" and "Device"
        lists, the k-th occurrence of a value ``v`` yields the name
        ``"<Section>_<v>_<k>"``. Occurrence counters restart for every pattern.
        """
        names = set()
        for patterns in data.values():
            for pattern in patterns:
                for section in ("Precond", "Postcond", "Device"):
                    seen = Counter()
                    for item in pattern.get(section, []):
                        seen[item] += 1
                        names.add(f"{section}_{item}_{seen[item]}")
        return sorted(names)

    def __encode_operations_merge(self, data, feature_space):
        """Encode each operation as the union of the features of all its patterns.

        Returns one dict per operation, in the iteration order of ``data``.
        Each dict starts with every name of ``feature_space`` set to 0, and
        every occurrence-indexed feature found in any pattern of the
        operation is set to 1.
        """
        encoded = []
        for patterns in data.values():
            flags = dict.fromkeys(feature_space, 0)
            for pattern in patterns:
                for section in ("Precond", "Postcond", "Device"):
                    # Occurrence counters restart for each pattern and section.
                    occurrences = Counter()
                    for item in pattern.get(section, []):
                        occurrences[item] += 1
                        flags[f"{section}_{item}_{occurrences[item]}"] = 1
            encoded.append(flags)
        return encoded
    
    def __encode_operations_select(self, data, feature_space):
        """Encode each operation from the single pattern with the most parameters.

        The chosen pattern is the one with the largest combined length of its
        "Precond", "Postcond" and "Device" lists (the first one on ties).
        Returns one dict per operation with every name of ``feature_space``
        set to 0 and the chosen pattern's occurrence-indexed features set to 1.
        """
        sections = ("Precond", "Postcond", "Device")

        def parameter_total(pattern):
            return sum(len(pattern.get(section, [])) for section in sections)

        encoded = []
        for patterns in data.values():
            richest = max(patterns, key=parameter_total)
            flags = dict.fromkeys(feature_space, 0)
            for section in sections:
                occurrences = Counter()
                for item in richest.get(section, []):
                    occurrences[item] += 1
                    flags[f"{section}_{item}_{occurrences[item]}"] = 1
            encoded.append(flags)
        return encoded

    def __convert_patterns_to_features(self, patterns):
        """Turn each pattern into a plain dict of feature-name counts.

        Values of the "Precond", "Postcond" and "Device" lists are counted
        under the keys ``"Precond_<v>"``, ``"Postcond_<v>"`` and
        ``"Device_<v>"`` respectively; one dict is returned per pattern.
        """
        converted = []
        for pattern in patterns:
            tally = Counter(
                f"{section}_{item}"
                for section in ("Precond", "Postcond", "Device")
                for item in pattern.get(section, [])
            )
            converted.append(dict(tally))
        return converted